package flyway

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"time"

	"gitee.com/kristas/booting-go/framework/common/util/cryptox"
	"gitee.com/kristas/booting-go/framework/common/util/lang"
	"gitee.com/kristas/booting-go/framework/common/util/sort"
	"gitee.com/kristas/booting-go/framework/core/configure"
	"gitee.com/kristas/booting-go/framework/core/log"
	"gitee.com/kristas/booting-go/framework/core/statement/types"
	"gorm.io/gorm"
)

// Configure holds the settings of the flyway migration runner.
//
// The embedded types.AutoConfigure together with its
// `prefix:"application.flyway"` tag presumably lets the framework bind these
// fields from the application.flyway configuration section (TODO confirm
// against configure.BindConfiguration). NewGormMigrator fills in defaults
// before binding.
type Configure struct {
	types.AutoConfigure   `prefix:"application.flyway"`
	// Enable switches the whole migration run on or off; Execute does
	// nothing when it is false.
	Enable                bool   `yaml:"enable"`
	// Locations is the directory that is walked for migration files and
	// from which the matched script files are read.
	Locations             string `yaml:"locations"`
	// Table is the name of the schema-history table that records applied
	// scripts.
	Table                 string `yaml:"table"`
	// SqlMigrationPrefix precedes the version in a migration file name
	// (default "V").
	SqlMigrationPrefix    string `yaml:"sql_migration_prefix"`
	// SqlMigrationSuffixes ends a migration file name (default ".sql").
	// Despite the plural name, it is used as one string: it is placed into
	// the file-name pattern and stripped from the description as a whole,
	// without any list splitting.
	SqlMigrationSuffixes  string `yaml:"sql_migration_suffixes"`
	// SqlMigrationSeparator separates the version from the description in a
	// migration file name (default "__").
	SqlMigrationSeparator string `yaml:"sql_migration_separator"`
}

// Migrator applies versioned SQL migration scripts to a gorm database and
// records each applied script in a schema-history table. Build one with
// NewGormMigrator and run it with Execute.
type Migrator struct {
	// db is the database the migration scripts are executed against.
	db      *gorm.DB
	// Conf holds the migration settings.
	Conf    *Configure
	// scripts holds the migration files discovered by prepareScripts during
	// Execute; sortScripts orders and ranks them.
	scripts []FlywaySchemaHistory
	// Log is not set by NewGormMigrator; the `wire:""` tag presumably marks
	// it for injection by the framework — confirm. It is not referenced by
	// the migration steps in this file.
	Log     log.Logger `wire:""`
}

// NewGormMigrator returns a Migrator that runs migrations against db.
//
// The configuration starts from built-in defaults: enabled, scripts in the
// "migration" directory, history in "flyway_schema_history", and file names
// of the form V<version>__<description>.sql. It is then handed to
// configure.BindConfiguration, which presumably overlays values from the
// application.flyway section. A binding failure panics.
func NewGormMigrator(db *gorm.DB) *Migrator {
	conf := new(Configure)
	conf.Enable = true
	conf.Locations = "migration"
	conf.Table = "flyway_schema_history"
	conf.SqlMigrationPrefix = "V"
	conf.SqlMigrationSuffixes = ".sql"
	conf.SqlMigrationSeparator = "__"

	if bindErr := configure.BindConfiguration(conf); bindErr != nil {
		panic(bindErr)
	}

	migrator := &Migrator{db: db, Conf: conf}
	migrator.scripts = make([]FlywaySchemaHistory, 0)
	return migrator
}

// flow lists the steps Execute runs, in this order:
//  1. createSchemaTable runs the schema-history table DDL,
//  2. prepareScripts discovers migration files under Conf.Locations,
//  3. sortScripts orders them by version and assigns installed ranks,
//  4. executeScripts applies the ones not yet recorded.
//
// The order matters: each step relies on the state left by the previous one.
var flow = []func(*Migrator) error{
	createSchemaTable,
	prepareScripts,
	sortScripts,
	executeScripts,
}

// Execute runs every migration step in flow, in order, and returns the
// first error a step reports without running the remaining steps. When
// Conf.Enable is false it does nothing and returns nil.
func (r *Migrator) Execute() error {
	if !r.Conf.Enable {
		return nil
	}
	for _, step := range flow {
		if err := step(r); err != nil {
			return err
		}
	}
	return nil
}

// createSchemaTable executes the schema-history DDL template
// flywaySchemaHistorySql (defined elsewhere; presumably a
// CREATE TABLE IF NOT EXISTS statement — confirm) with Conf.Table substituted
// in as the table name.
//
// Failures are returned, wrapped with the table name, instead of panicking,
// so Execute can report them through its error result.
//
// The table name is formatted straight into the SQL because identifiers
// cannot be bound as parameters. Conf.Table must therefore come from trusted
// configuration.
func createSchemaTable(m *Migrator) error {
	ddl := fmt.Sprintf(flywaySchemaHistorySql, m.Conf.Table)
	if err := m.db.Exec(ddl).Error; err != nil {
		return fmt.Errorf("creating schema history table %q: %w", m.Conf.Table, err)
	}
	return nil
}

// prepareScripts walks Conf.Locations recursively. For every file whose name
// has the form <prefix><version><separator><description><suffix> (for
// example "V1_2__create_users.sql" with the default settings), it appends a
// pending FlywaySchemaHistory entry of type "SQL" to m.scripts.
//
// Version is the text between the prefix and the first separator, with its
// own '.'/'_' separators kept (for example "1_2"). Description is everything
// after the first separator, minus the suffix. m.scripts is reset first, so
// running Execute more than once does not queue the same files twice.
//
// Any error reported by the directory walk is returned.
//
// NOTE(review): subdirectories are walked, but readScriptFiles resolves a
// script as Locations/<file name>. A matching file in a nested directory is
// therefore discovered but cannot be loaded. Confirm whether nested
// directories should be supported or skipped.
func prepareScripts(m *Migrator) error {
	pattern, err := scriptNamePattern(m.Conf)
	if err != nil {
		return err
	}
	m.scripts = m.scripts[:0]
	return filepath.Walk(m.Conf.Locations, func(_ string, info os.FileInfo, walkErr error) error {
		// filepath.Walk passes a nil info when it cannot stat the entry (for
		// example when Locations does not exist), so check the error first.
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() || !pattern.MatchString(info.Name()) {
			return nil
		}
		name := info.Name()
		// Cut only at the first separator so that a description which itself
		// contains the separator is kept whole.
		version, description, _ := strings.Cut(name, m.Conf.SqlMigrationSeparator)
		m.scripts = append(m.scripts, FlywaySchemaHistory{
			Version:     strings.TrimPrefix(version, m.Conf.SqlMigrationPrefix),
			Description: strings.TrimSuffix(description, m.Conf.SqlMigrationSuffixes),
			Type:        "SQL",
			Script:      name,
		})
		return nil
	})
}

// scriptNamePattern compiles the regular expression a migration file name
// must fully match. The configured prefix, separator and suffix are matched
// literally (regexp.QuoteMeta), so the "." in ".sql" no longer matches an
// arbitrary character. The pattern is anchored at both ends.
func scriptNamePattern(c *Configure) (*regexp.Regexp, error) {
	expr := fmt.Sprintf(`^%s\d+([_.]\d+)*%s[A-Za-z0-9\-_]+%s$`,
		regexp.QuoteMeta(c.SqlMigrationPrefix),
		regexp.QuoteMeta(c.SqlMigrationSeparator),
		regexp.QuoteMeta(c.SqlMigrationSuffixes))
	re, err := regexp.Compile(expr)
	if err != nil {
		return nil, fmt.Errorf("compiling migration file name pattern %q: %w", expr, err)
	}
	return re, nil
}

// sortScripts orders m.scripts by ascending version and then assigns
// InstalledRank 1..n in that order. Versions are compared component by
// component and numerically (see compareVersions), so "2" sorts before
// "10" and "1_2" before "1_10". Plain string ordering would get both wrong.
func sortScripts(m *Migrator) error {
	sort.Slice(m.scripts, func(i, j int) bool {
		return compareVersions(m.scripts[i].Version, m.scripts[j].Version) < 0
	})
	for i := range m.scripts {
		m.scripts[i].InstalledRank = i + 1
	}
	return nil
}

// executeScripts applies, in order, every script in m.scripts whose version
// is greater than the highest version recorded in the history table.
//
// For each script, applied or not, it:
//   - loads the file,
//   - computes its MD5 checksum,
//   - calls checkHistoryScript, which panics if a previously applied script
//     of the same name now has a different checksum.
//
// Each pending script runs in its own transaction together with the insert
// of its history row. The first failure is returned, wrapped with the script
// name, and later scripts are not attempted.
// NOTE(review): some databases (e.g. MySQL) implicitly commit DDL, so
// statements of a failed script that already ran may not be rolled back.
func executeScripts(m *Migrator) error {
	latestVersion := getLatestSchemaRecordVersion(m)
	for i := range m.scripts {
		history := m.scripts[i]
		file := readScriptFiles(history, m.Conf.Locations)
		history.Checksum = cryptox.MD5(file)
		checkHistoryScript(m, history)
		// Scripts at or below the latest recorded version count as applied.
		// An empty latestVersion means nothing has been applied yet.
		if latestVersion != "" && compareVersions(history.Version, latestVersion) <= 0 {
			continue
		}
		if err := applyScript(m, &history, file); err != nil {
			return fmt.Errorf("applying migration %q: %w", history.Script, err)
		}
	}
	return nil
}

// applyScript runs the statements of one migration file inside a
// transaction. If all of them succeed, it inserts history into the history
// table in the same transaction, with Success=true and ExecutionTime in
// milliseconds.
func applyScript(m *Migrator, history *FlywaySchemaHistory, file []byte) error {
	return m.db.Transaction(func(tx *gorm.DB) error {
		start := time.Now()
		if err := executeScriptFile(file, tx); err != nil {
			return err
		}
		history.ExecutionTime = int(time.Since(start).Milliseconds())
		history.Success = true
		return tx.Table(m.Conf.Table).Create(history).Error
	})
}

// getLatestSchemaRecordVersion returns the highest version recorded in the
// history table, or "" when the table has no rows.
//
// The maximum is computed in Go with compareVersions rather than with
// ORDER BY. The versions are stored as strings, and a textual ORDER BY ranks
// "9" above "10", which previously caused newer scripts to be skipped.
//
// It panics if the table cannot be read: treating a read failure as
// "nothing applied" would re-run every migration.
func getLatestSchemaRecordVersion(m *Migrator) string {
	var versions []string
	if err := m.db.Table(m.Conf.Table).Pluck("version", &versions).Error; err != nil {
		panic(fmt.Errorf("reading applied versions from %q: %w", m.Conf.Table, err))
	}
	latest := ""
	for _, v := range versions {
		if latest == "" || compareVersions(v, latest) > 0 {
			latest = v
		}
	}
	return latest
}

// compareVersions compares two migration versions such as "1", "1.2" or
// "1_10" and returns -1, 0 or +1. Components are split on '.' and '_' and
// compared numerically. Missing trailing components count as 0, so "1" and
// "1.0" are equal.
func compareVersions(a, b string) int {
	pa, pb := versionParts(a), versionParts(b)
	for i := 0; i < len(pa) || i < len(pb); i++ {
		var x, y string
		if i < len(pa) {
			x = pa[i]
		}
		if i < len(pb) {
			y = pb[i]
		}
		if c := compareDigits(x, y); c != 0 {
			return c
		}
	}
	return 0
}

// versionParts splits a version string into its '.'/'_' separated components.
func versionParts(v string) []string {
	return strings.FieldsFunc(v, func(r rune) bool { return r == '.' || r == '_' })
}

// compareDigits compares two decimal digit strings by numeric value without
// parsing them, so arbitrarily long components cannot overflow. Leading
// zeros are ignored and "" counts as 0. Non-digit input still gets a
// deterministic order: shorter first, then byte-wise.
func compareDigits(x, y string) int {
	x = strings.TrimLeft(x, "0")
	y = strings.TrimLeft(y, "0")
	if len(x) != len(y) {
		if len(x) < len(y) {
			return -1
		}
		return 1
	}
	return strings.Compare(x, y)
}

// readScriptFiles returns the raw contents of the migration file
// history.Script located directly inside the locations directory. It panics
// if the file cannot be read.
//
// filepath.Join builds an OS-correct path. The previous "%s/%s" formatting
// produced "/<name>" (the filesystem root) for an empty locations value.
func readScriptFiles(history FlywaySchemaHistory, locations string) []byte {
	file, err := os.ReadFile(filepath.Join(locations, history.Script))
	if err != nil {
		panic(err)
	}
	return file
}

// checkHistoryScript panics if history.Script has a successful record in
// the history table whose checksum differs from history.Checksum, meaning
// the file was edited after it was applied.
//
// It also panics when the history lookup itself fails. Ignoring that error,
// as before, silently disabled the changed-script check.
func checkHistoryScript(m *Migrator, history FlywaySchemaHistory) {
	var record FlywaySchemaHistory
	err := m.db.Table(m.Conf.Table).Where("script = ?", history.Script).Find(&record).Error
	if err != nil {
		panic(fmt.Errorf("looking up history of script %q: %w", history.Script, err))
	}
	if record.Success && record.Checksum != history.Checksum {
		panic(errors.New(fmt.Sprintf("scripts %s has been changed", history.Script)))
	}
}

// executeScriptFile splits a migration file on ';' and executes each
// non-blank, trimmed statement (with ';' re-appended) on tx, in file order.
// It stops at, and returns, the first execution error.
//
// NOTE(review): the split is purely textual. A ';' inside a string literal,
// quoted identifier, comment or procedure body also splits the script, so
// such statements reach the database as broken fragments. A dialect-aware
// splitter would be needed to support them.
func executeScriptFile(file []byte, tx *gorm.DB) error {
	for _, fragment := range lang.NewString(string(file)).Split(";") {
		statement := fragment.Trims()
		if statement.IsEmpty() {
			continue
		}
		if execErr := tx.Exec(statement.Concat(";").String()).Error; execErr != nil {
			return execErr
		}
	}
	return nil
}
